# Copyright (c) 2017, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA

"""This module contains helper functions."""

import binascii
import decimal
import functools
import inspect
import warnings

from typing import Any, Callable, List, Optional, Union

from .constants import SUPPORTED_TLS_VERSIONS, TLS_CIPHER_SUITES
from .errors import InterfaceError
from .types import EscapeTypes, StrOrBytes

# Built-in binary sequence types.
BYTE_TYPES = (
    bytearray,
    bytes,
)
# Numeric types; ``escape`` returns values of these types unchanged.
NUMERIC_TYPES = (
    int,
    float,
    decimal.Decimal,
)


def encode_to_bytes(value: StrOrBytes, encoding: str = "utf-8") -> bytes:
    """Returns an encoded version of the string as a bytes object.

    Binary input is passed through: ``bytes`` is returned unchanged and a
    ``bytearray`` is copied into an immutable ``bytes`` object, so the
    declared ``bytes`` return type always holds.

    Args:
        value (Union[str, bytes, bytearray]): The value to be encoded.
        encoding (str): The encoding used when ``value`` is a string.

    Returns:
        bytes: The encoded version of the string as a bytes object.

    Raises:
        UnicodeEncodeError: If a string ``value`` cannot be represented in
            ``encoding``.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, bytearray):
        # bytearray has no ``encode`` method; convert it instead of failing.
        return bytes(value)
    return value.encode(encoding)


def decode_from_bytes(value: StrOrBytes, encoding: str = "utf-8") -> str:
    """Returns a string decoded from the given bytes.

    ``bytes`` and ``bytearray`` values are decoded; any other value (normally
    already a ``str``) is returned unchanged.

    Args:
        value (Union[str, bytes, bytearray]): The value to be decoded.
        encoding (str): The encoding.

    Returns:
        str: The value decoded from bytes.

    Raises:
        UnicodeDecodeError: If a binary ``value`` is not valid ``encoding``.
    """
    # bytearray also has ``decode``; without it here a bytearray would be
    # returned as-is, violating the ``str`` return type.
    if isinstance(value, (bytes, bytearray)):
        return value.decode(encoding)
    return value


def get_item_or_attr(obj: object, key: str) -> Any:
    """Look up ``key`` on ``obj`` by subscription or by attribute access.

    Dictionaries (including ``dict`` subclasses) are indexed with
    ``obj[key]``; every other object is read with ``getattr(obj, key)``.

    Args:
        obj (object): Dictionary or object.
        key (str): Dictionary key or attribute name.

    Returns:
        object: The object for the provided key.

    Raises:
        KeyError: If ``obj`` is a dict that does not contain ``key``.
        AttributeError: If ``obj`` is not a dict and has no attribute ``key``.
    """
    if not isinstance(obj, dict):
        return getattr(obj, key)
    return obj[key]


def escape(*args: EscapeTypes) -> Union[EscapeTypes, List[EscapeTypes]]:
    """Escapes special characters as they are expected to be when MySQL
    receives them.
    As found in MySQL source mysys/charset.c

    Backslash, newline, carriage return, single quote, double quote and
    ``\\032`` are escaped with a backslash in ``str``, ``bytes`` and
    ``bytearray`` values. ``None`` and values of ``NUMERIC_TYPES`` are
    returned unchanged.

    Args:
        *args (object): Values to be escaped.

    Returns:
        object: The escaped value when exactly one argument is given;
        otherwise a list with one escaped value per argument (an empty list
        when called without arguments).
    """
    # (old, new) pairs applied in order. The backslash must come first so
    # the backslashes introduced by the later replacements are not doubled.
    str_escapes = (
        ("\\", "\\\\"),
        ("\n", "\\n"),
        ("\r", "\\r"),
        ("\047", "\134\047"),  # single quotes
        ("\042", "\134\042"),  # double quotes
        ("\032", "\134\032"),  # for Win32
    )
    # The same pairs for binary values; every character above is ASCII.
    bytes_escapes = tuple(
        (old.encode("ascii"), new.encode("ascii")) for old, new in str_escapes
    )

    def _escape(value: EscapeTypes) -> EscapeTypes:
        """Escapes special characters in a single value."""
        if value is None or isinstance(value, NUMERIC_TYPES):
            return value
        if isinstance(value, (bytes, bytearray)):
            pairs = bytes_escapes
        else:
            pairs = str_escapes
        for old, new in pairs:
            value = value.replace(old, new)
        return value

    if len(args) == 1:
        return _escape(args[0])
    # Zero or several arguments: one escaped value per argument.
    return [_escape(arg) for arg in args]


def quote_identifier(identifier: str, sql_mode: str = "") -> str:
    """Quote the given identifier with backticks, converting backticks (`)
    in the identifier name with the correct escape sequence (``) unless the
    identifier is quoted (") as in sql_mode set to ANSI_QUOTES.

    Args:
        identifier (str): Identifier to quote.
        sql_mode (str): The SQL mode, either a single mode name or a
            comma-separated list as reported by the server (for example
            ``"STRICT_TRANS_TABLES,ANSI_QUOTES"``). Double quotes are used
            when ``ANSI_QUOTES`` is one of the listed modes.

    Returns:
        str: The identifier quoted with backticks, or with double quotes
        (embedded ``"`` doubled) when ANSI_QUOTES is enabled.
    """
    # The server reports sql_mode as a comma-separated list, so match
    # ANSI_QUOTES as a token rather than requiring it to be the only mode.
    modes = {mode.strip().upper() for mode in (sql_mode or "").split(",")}
    if "ANSI_QUOTES" in modes:
        quoted = identifier.replace('"', '""')
        return f'"{quoted}"'
    quoted = identifier.replace("`", "``")
    return f"`{quoted}`"


def deprecated(version: Optional[str] = None, reason: Optional[str] = None) -> Callable:
    """This is a decorator used to mark functions as deprecated.

    Every call of the decorated function emits a ``DeprecationWarning``
    attributed to the caller's file and line, then calls the function.

    Args:
        version (Optional[string]): Version when was deprecated.
        reason (Optional[string]): Reason or extra information to be shown.

    Returns:
        Callable: A decorator used to mark functions as deprecated.

    Usage:

    .. code-block:: python

       from mysqlx.helpers import deprecated

       @deprecated('8.0.12', 'Please use other_function() instead')
       def deprecated_function(x, y):
           return x + y
    """

    def decorate(func: Callable) -> Callable:
        """Decorate function."""

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper function.

            Args:
                *args: Variable length argument list.
                **kwargs: Arbitrary keyword arguments.

            Returns:
                The return value of the wrapped function.
            """
            message = [f"'{func.__name__}' is deprecated"]
            if version:
                message.append(f" since version {version}")
            if reason:
                message.append(f". {reason}")
            text = "".join(message)

            current = inspect.currentframe()
            caller = current.f_back if current is not None else None
            try:
                if caller is None:
                    # currentframe() may return None on Python implementations
                    # without stack frame support; let warn() locate the caller.
                    warnings.warn(text, category=DeprecationWarning, stacklevel=2)
                else:
                    warnings.warn_explicit(
                        text,
                        category=DeprecationWarning,
                        filename=inspect.getfile(caller.f_code),
                        lineno=caller.f_lineno,
                    )
            finally:
                # Drop the frame references to avoid reference cycles, as
                # recommended by the inspect module documentation.
                del current, caller
            return func(*args, **kwargs)

        return wrapper

    return decorate


def iani_to_openssl_cs_name(
    tls_version: str, cipher_suites_names: List[str]
) -> List[str]:
    """Translates a cipher suites names list; from IANA names to OpenSSL names.

    The lookup table merges the mappings in ``TLS_CIPHER_SUITES`` for every
    version listed in ``SUPPORTED_TLS_VERSIONS`` up to and including
    ``tls_version``. Names containing a dash (``-``) are taken to be OpenSSL
    names already and are passed through unchanged.

    Args:
        tls_version (str): The TLS version to look at for a translation.
        cipher_suites_names (list): A list of cipher suites names.

    Returns:
        List[str]: List of translated names, in the input order.

    Raises:
        InterfaceError: If a name has no dash and is not found in the merged
            lookup table.
        ValueError: If ``tls_version`` is not in ``SUPPORTED_TLS_VERSIONS``
            (raised by its ``index`` lookup).
    """
    last = SUPPORTED_TLS_VERSIONS.index(tls_version)

    # Merge the mappings of the given TLS version and the versions listed
    # before it; for a duplicated name, the later version's entry wins.
    cipher_suites = {}
    for version in SUPPORTED_TLS_VERSIONS[: last + 1]:
        cipher_suites.update(TLS_CIPHER_SUITES[version])

    translated_names = []
    for name in cipher_suites_names:
        if "-" in name:
            translated_names.append(name)
        elif name in cipher_suites:
            translated_names.append(cipher_suites[name])
        else:
            raise InterfaceError(
                f"The '{name}' in cipher suites is not a valid cipher suite"
            )
    return translated_names


def hexlify(data: bytes) -> str:
    """Convert binary data into its hexadecimal text form.

    Each input byte becomes two lowercase hexadecimal digits.

    Args:
        data (bytes): The binary data.

    Returns:
        str: The decoded hexadecimal representation of data.
    """
    hex_digits = binascii.hexlify(data)
    return str(hex_digits, "utf-8")
